{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SubPopulation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This page gives a quick overview of how to start using `SubPopulation`, a family of textflint methods for verifying robustness comprehensively. The full list of `SubPopulation`s can be found on our [website](https://www.textflint.com) or on [GitHub](https://github.com/textflint/textflint)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How to use a built-in SubPopulation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "textflint offers multiple universal `SubPopulation` methods for NLP tasks, and we will provide task-specific `SubPopulation` methods in a coming version. Here we use the `LM` `SubPopulation` on the Sentiment Analysis task as a brief introduction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34;1mtextflint\u001b[0m: ******Start load!******\n",
      "100%|██████████| 3/3 [00:00<00:00, 1950.23it/s]\n",
      "\u001b[34;1mtextflint\u001b[0m: 3 in total, 3 were loaded successful.\n",
      "\u001b[34;1mtextflint\u001b[0m: ******Finish load!******\n",
      "100%|██████████| 3/3 [00:46<00:00, 15.59s/it]\n"
     ]
    }
   ],
   "source": [
    "# 1. Import the SA Sample, textflint Dataset and LM SubPopulation method\n",
    "from textflint.input_layer.component.sample.sa_sample import SASample\n",
    "from textflint.input_layer.dataset import Dataset\n",
    "from textflint.generation_layer.subpopulation.UT import LMSubPopulation\n",
    "\n",
    "# 2. Initialize the SA Sample\n",
    "sample1 = {'x': 'Titanic is my favorite movie.', 'y': 'pos'}\n",
    "sample2 = {'x': 'I don\\'t like the actor Tim Hill', 'y': 'neg'}\n",
    "sample3 = {'x': 'The leading actor is good.', 'y': 'pos'}\n",
    "samples = [sample1, sample2, sample3]\n",
    "\n",
    "# 3. Construct the Dataset\n",
    "dataset = Dataset('SA')\n",
    "dataset.load(samples)\n",
    "\n",
    "# 4. Define the SubPopulation\n",
    "sub = LMSubPopulation(intervals=[0, 1])\n",
    "\n",
    "# 5. Run SubPopulation on Dataset\n",
    "sub_dataset = sub.slice_population(dataset, 'x')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can save the sub-dataset as a JSON file in a predefined directory through the `Dataset.save_json` interface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34;1mtextflint\u001b[0m: Save samples to ./test_result/result.json!\n"
     ]
    }
   ],
   "source": [
    "# output path\n",
    "path = './test_result/'\n",
    "sub_dataset.save_json(path + 'result.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['result.json']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "os.listdir(path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define your own SubPopulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from textflint.generation_layer.subpopulation import SubPopulation\n",
    "\n",
    "class LengthStr(SubPopulation):\n",
    "    r\"\"\"\n",
    "    Filter samples based on string length\n",
    "    \"\"\"\n",
    "    \n",
    "    def _score(self, sample, fields, **kwargs):\n",
    "        r\"\"\"\n",
    "        Score the sample\n",
    "\n",
    "        :param sample: data sample\n",
    "        :param list fields: list of field str\n",
    "        :param kwargs:\n",
    "        :return int: score for sample\n",
    "        \"\"\"\n",
    "        return len(sample.get_text(fields[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`SubPopulation` requires you to implement the abstract method `_score`, which assigns a score to a `Sample`. The code cell above defines a new `SubPopulation` method, `LengthStr`, whose score is the length of the string.\n",
    "\n",
    "`fields` is a list of field names of the input `Sample`, and the score is computed from the values of these fields. `LengthStr` uses only the first field in the list."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 6654.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'x': 'The leading actor is good.', 'y': 'pos', 'sample_id': 2}\n",
      "{'x': 'Titanic is my favorite movie.', 'y': 'pos', 'sample_id': 0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "test_sub = LengthStr(intervals=[0, 2])\n",
    "test_dataset = test_sub.slice_population(dataset, 'x')\n",
    "print(test_dataset[0].dump())\n",
    "print(test_dataset[1].dump())"
   ]
  },
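  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Judging from the output above, `intervals=[0, 2]` appears to keep the samples at positions 0 and 1 once the dataset is ordered by ascending score. The plain-Python sketch below mimics that behavior on the same three sentences. It is an illustration inferred from the printed result, not textflint's actual implementation, and the helper `slice_by_score` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper, not part of textflint: an assumed reading of `intervals`.\n",
    "def slice_by_score(texts, score_fn, intervals):\n",
    "    start, end = intervals\n",
    "    # Order the texts by ascending score, then keep positions [start, end).\n",
    "    ranked = sorted(texts, key=score_fn)\n",
    "    return ranked[start:end]\n",
    "\n",
    "texts = [\n",
    "    'Titanic is my favorite movie.',\n",
    "    'I don\\'t like the actor Tim Hill',\n",
    "    'The leading actor is good.',\n",
    "]\n",
    "\n",
    "# Using string length as the score mirrors LengthStr._score.\n",
    "for text in slice_by_score(texts, len, [0, 2]):\n",
    "    print(len(text), text)"
   ]
  },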
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'x': 'Titanic is my favorite movie.', 'y': 'pos', 'sample_id': 0}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_dataset[1].dump()"
   ]
  },
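  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Other scoring rules follow the same pattern. As a minimal sketch, the hypothetical `WordCount` class below (not part of textflint) scores a sample by its number of whitespace-separated tokens. It relies only on the `SubPopulation`, `get_text` and `slice_population` interfaces used earlier on this page, and it assumes the `dataset` object built in the first code cell is still in scope."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from textflint.generation_layer.subpopulation import SubPopulation\n",
    "\n",
    "class WordCount(SubPopulation):\n",
    "    r\"\"\"\n",
    "    Filter samples based on the number of whitespace-separated tokens\n",
    "    (illustrative example, not a built-in textflint method)\n",
    "    \"\"\"\n",
    "\n",
    "    def _score(self, sample, fields, **kwargs):\n",
    "        # Only the first field is scored, as in LengthStr above.\n",
    "        return len(sample.get_text(fields[0]).split())\n",
    "\n",
    "# `dataset` is the Dataset constructed in the first code cell.\n",
    "word_sub = WordCount(intervals=[0, 2])\n",
    "word_dataset = word_sub.slice_population(dataset, 'x')\n",
    "print(word_dataset[0].dump())"
   ]
  },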
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "In this section, we showed how to use the built-in `LMSubPopulation` and how to define our own `SubPopulation`. `textflint` currently implements only a few `SubPopulation`s, and we will add more task-specific ones, as with `Transformation`s."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
